// Copyright Epic Games, Inc. All Rights Reserved.

#include <zenstore/caslog.h>

#include <zencore/except.h>
#include <zencore/filesystem.h>
#include <zencore/fmtutils.h>
#include <zencore/logging.h>
#include <zencore/memory.h>
#include <zencore/string.h>

#include <xxhash.h>

#include <gsl/gsl-lite.hpp>

#include <filesystem>
#include <functional>

//////////////////////////////////////////////////////////////////////////

namespace zen {

uint32_t
CasLogFile::FileHeader::ComputeChecksum()
{
	return XXH32(&this->Magic, sizeof(FileHeader) - 4, 0xC0C0'BABA);
}

// Nothing to set up here; the log file is opened explicitly via Open().
CasLogFile::CasLogFile() = default;

// No explicit teardown; member destructors run as usual (Close() is not called here).
CasLogFile::~CasLogFile() = default;

/**
 * Checks whether FileName looks like a usable log file for the given record size.
 *
 * The file must be a regular file, be openable for reading, be at least one header long,
 * carry the expected magic sequence and a matching header checksum, and declare RecordSize
 * as its record size.
 *
 * @param FileName    Path of the log file to inspect.
 * @param RecordSize  Record size the caller expects the log to use.
 * @return true when all checks pass, false otherwise.
 */
bool
CasLogFile::IsValid(std::filesystem::path FileName, size_t RecordSize)
{
	// Use the non-throwing overload: a failure to query the file status (for example access
	// denied) should make this predicate answer "not valid" rather than throw filesystem_error.
	std::error_code StatusEc;
	if (!std::filesystem::is_regular_file(FileName, StatusEc))
	{
		return false;
	}

	BasicFile File;

	std::error_code Ec;
	File.Open(FileName, BasicFile::Mode::kRead, Ec);
	if (Ec)
	{
		return false;
	}

	FileHeader Header;
	if (File.FileSize() < sizeof(Header))
	{
		return false;
	}

	// Validate header and log contents and prepare for appending/replay
	File.Read(&Header, sizeof Header, 0);

	if ((0 != memcmp(Header.Magic, FileHeader::MagicSequence, sizeof Header.Magic)) || (Header.Checksum != Header.ComputeChecksum()))
	{
		return false;
	}
	if (Header.RecordSize != RecordSize)
	{
		return false;
	}
	return true;
}

void
CasLogFile::Open(std::filesystem::path FileName, size_t RecordSize, Mode Mode)
{
	m_RecordSize = RecordSize;

	std::error_code Ec;
	BasicFile::Mode FileMode = BasicFile::Mode::kRead;
	switch (Mode)
	{
		case Mode::kWrite:
			FileMode = BasicFile::Mode::kWrite;
			break;
		case Mode::kTruncate:
			FileMode = BasicFile::Mode::kTruncate;
			break;
	}

	m_File.Open(FileName, FileMode, Ec);
	if (Ec)
	{
		throw std::system_error(Ec, fmt::format("Failed to open log file '{}'", FileName));
	}

	uint64_t AppendOffset = 0;

	if ((Mode == Mode::kTruncate) || (m_File.FileSize() < sizeof(FileHeader)))
	{
		if (Mode == Mode::kRead)
		{
			throw std::runtime_error(fmt::format("Mangled log header (file to small) in '{}'", FileName));
		}
		// Initialize log by writing header
		FileHeader Header = {.RecordSize = gsl::narrow<uint32_t>(RecordSize), .LogId = Oid::NewOid(), .ValidatedTail = 0};
		memcpy(Header.Magic, FileHeader::MagicSequence, sizeof Header.Magic);
		Header.Finalize();

		m_File.Write(&Header, sizeof Header, 0);

		AppendOffset = sizeof(FileHeader);

		m_Header = Header;
	}
	else
	{
		FileHeader Header;
		m_File.Read(&Header, sizeof Header, 0);

		if ((0 != memcmp(Header.Magic, FileHeader::MagicSequence, sizeof Header.Magic)) || (Header.Checksum != Header.ComputeChecksum()))
		{
			throw std::runtime_error(fmt::format("Mangled log header (invalid header magic) in '{}'", FileName));
		}
		if (Header.RecordSize != RecordSize)
		{
			throw std::runtime_error(fmt::format("Mangled log header (mismatch in record size, expected {}, found {}) in '{}'",
												 RecordSize,
												 Header.RecordSize,
												 FileName));
		}

		AppendOffset = m_File.FileSize();

		// Adjust the offset to ensure we end up on a good boundary, in case there is some garbage appended

		AppendOffset -= sizeof Header;
		AppendOffset -= AppendOffset % RecordSize;
		AppendOffset += sizeof Header;

		m_Header = Header;
	}

	m_AppendOffset = AppendOffset;
}

void
CasLogFile::Close()
{
	// Push anything still pending to the file before the handle is released.
	// TODO: update header and maybe add trailer
	Flush();
	m_File.Close();
}

uint64_t
CasLogFile::GetLogSize()
{
	// Raw on-disk size in bytes: header plus records, plus any trailing bytes that do not
	// form a whole record.
	const uint64_t FileSizeInBytes = m_File.FileSize();
	return FileSizeInBytes;
}

uint64_t
CasLogFile::GetLogCount()
{
	uint64_t LogFileSize = m_AppendOffset.load(std::memory_order_acquire);
	if (LogFileSize < sizeof(FileHeader))
	{
		return 0;
	}
	const uint64_t LogBaseOffset = sizeof(FileHeader);
	const size_t   LogEntryCount = (LogFileSize - LogBaseOffset) / m_RecordSize;
	return LogEntryCount;
}

/**
 * Invokes Handler once per whole record stored in the log, in file order, skipping the
 * first SkipEntryCount records.
 *
 * Records are read in chunks of at most ~1 MiB (but always at least one record). Any
 * trailing bytes that do not form a whole record are ignored. After a replay that visits
 * at least one record, the append offset is moved to the end of the last whole record.
 *
 * @param Handler         Called with a pointer to each record's bytes; the pointer is only
 *                        valid for the duration of the call (the buffer is reused).
 * @param SkipEntryCount  Number of leading records to skip.
 */
void
CasLogFile::Replay(std::function<void(const void*)>&& Handler, uint64_t SkipEntryCount)
{
	const uint64_t LogFileSize	 = m_File.FileSize();
	uint64_t	   LogBaseOffset = sizeof(FileHeader);

	// A file shorter than the header holds no records; without this check the unsigned
	// subtraction below would wrap around to an enormous entry count.
	if (LogFileSize < LogBaseOffset)
	{
		return;
	}

	// Integer division ensures we end up on a clean record boundary
	uint64_t LogEntryCount = (LogFileSize - LogBaseOffset) / m_RecordSize;

	if (LogEntryCount <= SkipEntryCount)
	{
		return;
	}

	LogBaseOffset += SkipEntryCount * m_RecordSize;
	LogEntryCount -= SkipEntryCount;

	const uint64_t LogDataSize	 = LogEntryCount * m_RecordSize;
	uint64_t	   LogDataRemain = LogDataSize;

	const uint64_t MaxBufferSize = 1024 * 1024;

	// The buffer holds a whole number of records. It must hold at least one: if a single
	// record is larger than MaxBufferSize the buffer would otherwise be empty, every read
	// would be zero bytes and the loop below would never terminate.
	uint64_t EntriesPerRead = Min(LogDataSize, MaxBufferSize) / m_RecordSize;
	if (EntriesPerRead == 0)
	{
		EntriesPerRead = 1;
	}

	std::vector<uint8_t> ReadBuffer;
	ReadBuffer.resize(EntriesPerRead * m_RecordSize);

	uint64_t ReadOffset = 0;

	while (LogDataRemain)
	{
		const uint64_t BytesToRead	 = Min(ReadBuffer.size(), LogDataRemain);
		const uint64_t EntriesToRead = BytesToRead / m_RecordSize;

		m_File.Read(ReadBuffer.data(), BytesToRead, LogBaseOffset + ReadOffset);

		for (uint64_t EntryIndex = 0; EntryIndex < EntriesToRead; ++EntryIndex)
		{
			Handler(ReadBuffer.data() + (EntryIndex * m_RecordSize));
		}

		LogDataRemain -= BytesToRead;
		ReadOffset += BytesToRead;
	}

	m_AppendOffset = LogBaseOffset + (m_RecordSize * LogEntryCount);
}

void
CasLogFile::Append(const void* DataPointer, uint64_t DataSize)
{
	ZEN_ASSERT((DataSize % m_RecordSize) == 0);

	uint64_t AppendOffset = m_AppendOffset.fetch_add(DataSize);

	std::error_code Ec;
	m_File.Write(DataPointer, gsl::narrow<uint32_t>(DataSize), AppendOffset, Ec);

	if (Ec)
	{
		std::error_code Dummy;
		throw std::system_error(Ec, fmt::format("Failed to write to log file '{}'", PathFromHandle(m_File.Handle(), Dummy)));
	}
}

void
CasLogFile::Flush()
{
	// Thin forwarder; flush semantics are those of BasicFile::Flush.
	m_File.Flush();
}

}  // namespace zen
